{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "e71afb5e-ac05-4837-98b3-8739dcc3af4f",
   "metadata": {},
   "source": [
    "# Building an Image Recognition Neural Network\n",
    "\n",
    "The model we'll build will use the classic MNIST dataset with handwritten numbers. We'll also cover how to construct a more robust model, make predictions, and get the meta data to determine how well the model actually performs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "3fecdd7c-057c-4111-9cc2-e9c695fb8215",
   "metadata": {},
   "outputs": [],
   "source": [
    "# imports \n",
    "\n",
    "import numpy\n",
    "import pandas\n",
    "import matplotlib.pyplot as plt\n",
    "import tensorflow\n",
    "import tensorflow.keras\n",
    "from tensorflow.keras.models import Sequential\n",
    "from tensorflow.keras.layers import Dense, Dropout, LeakyReLU\n",
    "\n",
    "# Setting random seeds to get reproducible results\n",
    "numpy.random.seed(0)\n",
    "tensorflow.random.set_seed(1)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9cf10fcd-e0a3-437e-bc0e-9ab420fddc2d",
   "metadata": {},
   "source": [
    "## Reading the data\n",
    "\n",
    "The built-in dataset is already balanced, so we don't need to worry about stratifying classes to ensure equal representation between the splits. In practice, you could use something like Sklearn's train_test_split."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "41fb4676-7cf2-4f9d-87c4-7c9767d10eda",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training set size: 60000\n",
      "Testing set size: 10000\n",
      "The label is 2\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAAOhUlEQVR4nO3dfaxU9Z3H8c8XaDXyEFFuCBGyF4nGkMXSOsE1NZWVWEFNsDHBYqzUEGl8Sps0UdNNqH9oQtalSOKCwoqwSwshtkZ8yG5daCQQJQ6GRdT4sAYCCNyLRpAIlIfv/nEP7i3e+Z3LnDMP8n2/kpuZOd8593w58OHMnN/M+Zm7C8C5b0CrGwDQHIQdCIKwA0EQdiAIwg4EMaiZGxsxYoR3dnY2c5NAKDt27NCBAwesr1qhsJvZVEkLJQ2U9G/uPi/1/M7OTlWr1SKbBJBQqVRq1up+GW9mAyX9q6RpksZLmmlm4+v9fQAaq8h79kmSPnb3T9z9r5JWS5peTlsAylYk7JdI2tXr8e5s2d8wszlmVjWzand3d4HNASii4Wfj3X2Ju1fcvdLR0dHozQGooUjY90ga0+vx6GwZgDZUJOxvSbrMzMaa2Xcl/VTS2nLaAlC2uofe3P2EmT0g6b/UM/S2zN3fLa0zAKUqNM7u7q9KerWkXgA0EB+XBYIg7EAQhB0IgrADQRB2IAjCDgRB2IEgCDsQBGEHgiDsQBCEHQiCsANBEHYgCMIOBEHYgSAIOxAEYQeCIOxAEIQdCIKwA0EQdiAIwg4EQdiBIAg7EARhB4Ig7EAQhB0IgrADQRB2IIhCs7gCx44dS9aPHz9es7Zx48bkunv27EnWZ82alawPGsQ/794K7Q0z2yHpS0knJZ1w90oZTQEoXxn/9f2jux8o4fcAaCDeswNBFA27S/qzmW0xszl9PcHM5phZ1cyq3d3dBTcHoF5Fw36tu/9A0jRJ95vZj858grsvcfeKu1c6OjoKbg5AvQqF3d33ZLddkl6QNKmMpgCUr+6wm9lgMxt6+r6kH0vaXlZjAMpV5Gz8SEkvmNnp3/MHd//PUrpC03zxxRfJ+vz585P19evXJ+ubN28+25b6LW8cfu7cuQ3b9rdR3WF3908kfa/EXgA0EENvQBCEHQiCsANBEHYgCMIOBMF3AM8BqY8hL1y4MLluXv3IkSPJursn62PHjq1Zu/jii5PrbtmyJVl/5plnkvV77723Zi3ipzk5sgNBEHYgCMIOBEHYgSAIOxAEYQeCIOxAEIyzt4GjR48m64899liyvnjx4pq1gwcP1tVTf02YMCFZf/3112vWTpw4kVx35MiRyfr+/fuT9dSfnXF2AOcswg4EQdiBIAg7EARhB4Ig7EAQhB0IgnH2NrBp06Zkfd68eU3q5JvGjx+frG/YsCFZHzZsWM3aZ599VldPqA9HdiAIwg4EQdiBIAg7EARhB4Ig7EAQhB0IgnH2NrB8+fKG/e7LL788Wb/++uuT9ccffzxZT42j59m5c2fd6+Ls5R7ZzWyZmXWZ2fZeyy4ys9fM7KPsdnhj2wRQVH9exi+XNPWMZY9IWuful0lalz0G0MZyw+7uGyR9fsbi6ZJWZPdXSLq13LYAlK3eE3Qj3X1vdn+fpJoXCzOzOWZWNbNqak4yAI1V+Gy898zsV3N2P3df4u4Vd69EvMgf0C7qDft+MxslSdltV3ktAWiEesO+VtKs7P4sSS+W0w6ARskdZzezVZImSxphZrsl/VbSPElrzGy2pJ2SZjSyyXPdokWLkvVrrrkmWZ869czBkv+Xd+31wYMHJ+uN1NXFC8Jmyg27u8+sUZpSci8AGoiPywJBEHYgCMIOBEHYgSAIOxAEX3FtA0OHDk3W77vvviZ10lzr169vdQuhcGQHgiDsQBCEHQiCsANBEHYgCMIOBEHYgSAYZw/u+eefT9YPHTqUrPdcqKg2M6tZ27JlS3LdPDfffHOyfumllxb6/ecajuxAEIQdCIKwA0EQdiAIwg4EQdiBIAg7EATj7N8Cx48fT9Y//fTTmrW5c+cm1125cmVdPZ126tSpZH3AgPqPJ2PGjEnWn3vuuYZt+1zE3gCCIOxAEIQdCIKwA0EQdiAIwg4EQdiBIBhnb4KTJ08m67t3707WJ0+enKzv2rWrZu2CCy5Irps3lj1t2rRkfdWqVcn64cOHk/WUEydOJOuvvPJKsn7HHXfUrA0cOLCunr7Nco/sZrbMzLrMbHuvZY+a2R4z25r93NTYNgEU1Z+X8cslTe1j+QJ3n5j9vFpuWwDKlht2d98g6fMm9AKggYqcoHvAzLZlL/OH13qSmc0xs6qZVbu7uwtsDkAR9YZ9saRxkiZK2itpfq0nuvsSd6+4e6Wjo6POzQEoqq6wu/t+dz/p7qckLZU0qdy2AJStrrCb2aheD38iaXut5wJoD7nj7Ga2StJkSSPMbLek30qabGYTJbmkHZJ+0bgW21/eOPrWrVuT9auvvrrQ9hctWlSzNmXKlOS648aNS9aPHDmSrG/bti1Z37x5c7Kesm/fvmT97rvvTtZT143P2+eDBp17H0HJ/RO5+8w+Fj/bgF4ANBAflwWCIOxAEIQdCIKwA0EQdiCIc298oUFSw2sLFy5MrvvQQw8V2nbqq5qSdNddd9WsnX/++cl1v/rqq2T9lltuSdbffPPNZP28886rWXviiSeS6+YNWeZdSvq6666rWZsxY0Zy3bxLcA8ZMiRZzzN69OhC69eDIzsQBGEHgiDsQBCEHQiCsANBEHYgCMIOBME4eyZv6uEnn3yyZu3hhx9Orjt06NBkffny5cn6jTfemKynxtJ37tyZXPeee+5J1jds2JCsT5gwIVlfvXp1zdoVV1yRXPfYsWPJ+oMPPpisL1u2rGZtxYoVyXXXrFmTrOdJfb1Wkj788MNCv78eHNmBIAg7EARhB4Ig7EAQhB0IgrADQRB2IAjG2TMvv/xysp4aS8/7bvNLL72UrF911VXJ+gcffJCsP/300zVrK1euTK6bd6nop556KlnP+679sGHDkvWU1HfhJenKK69M1lOfjbjtttuS6y5dujRZz7NgwYJC6zcCR3YgCMIOBEHYgSAIOxAEYQeCIOxAEIQdCMLcvWkbq1QqXq1Wm7a9s5F3He/U9MF512bPG0c/ePBgsr59+/ZkvYjFixcn67Nnz07WBwzgeNFOKpWKqtWq9VXL/ZsyszFm9hcze8/M3jWzX2bLLzKz18zso+x2eNmNAyhPf/5bPiHp1+4+XtI/SLrfzMZLekTSOne/TNK67DGANpUbdnff6+5vZ/e/lPS+pEskTZd0+to+KyTd2qAeAZTgrN5wmVmnpO9L2ixppLvvzUr7JI2ssc4cM6uaWbW7u7tIrwAK6HfYzWyIpD9K+pW7H+pd856zfH2e6XP3Je5ecfdKR0dHoWYB1K9fYTez76gn6L939z9li/eb2aisPkpSV2NaBFCG3K+4mplJelbS++7+u16ltZJmSZqX3b7YkA6bpLOzM1lPDb0dPXo0ue6mTZvqaelrd955Z7J+ww031KxNmzYtue6FF16YrDO0du7oz/fZfyjpZ5LeMbOt2bLfqCfka8xstqSdktITXgNoqdywu/tGSX0O0kuaUm47ABqF12hAEIQdCIKwA0EQdiAIwg4EwaWkM+vWrUvW33jjjZq1vHH0UaNGJeu33357sp73FdqBAwcm64DEkR0Ig7ADQRB2IAjCDgRB2IEgCDsQBGEHgmCcPZM3PfDkyZPrqgHtgiM7EARhB4Ig7EAQhB0IgrADQRB2IAjCDgRB2IEgCDsQBGEHgiDsQBCEHQiCsANBEHYgCMIOBJEbdjMbY2Z/MbP3zOxdM/tltvxRM9tjZluzn5sa3y6AevXn4hUnJP3a3d82s6GStpjZa1ltgbv/S+PaA1CW/szPvlfS3uz+l2b2vqRLGt0YgHKd1Xt2M+uU9H1Jm7NFD5jZNjNbZmbDa6wzx8yqZlbt7u4u1i2AuvU77GY2RNIfJf3K3Q9JWixpnKSJ6jnyz+9rPXdf4u4Vd690dHQU7xhAXfoVdjP7jnqC/nt3/5Mkuft+dz/p7qckLZU0qXFtAiiqP2fjTdKzkt5399/1Wt57atKfSNpefnsAytKfs/E/lPQzSe+Y2dZs2W8kzTSziZJc0g5Jv2hAfwBK0p+z8RslWR+lV8tvB0Cj8Ak6IAjCDgRB2IEgCDsQBGEHgiDsQBCEHQiCsANBEHYgCMIOBEHYgSAIOxAEYQeCIOxAEObuzduYWbeknb0WjZB0oGkNnJ127a1d+5LorV5l9vZ37t7n9d+aGvZvbNys6u6VljWQ0K69tWtfEr3Vq1m98TIeCIKwA0G0OuxLWrz9lHbtrV37kuitXk3praXv2QE0T6uP7ACahLADQbQk7GY21cw+MLOPzeyRVvRQi5ntMLN3smmoqy3uZZmZdZnZ9l7LLjKz18zso+y2zzn2WtRbW0zjnZhmvKX7rtXTnzf9PbuZDZT0oaQbJO2W9Jakme7+XlMbqcHMdkiquHvLP4BhZj+SdFjSv7v732fL/lnS5+4+L/uPcri7P9wmvT0q6XCrp/HOZisa1XuacUm3Svq5WrjvEn3NUBP2WyuO7JMkfezun7j7XyWtljS9BX20PXffIOnzMxZPl7Qiu79CPf9Ymq5Gb23B3fe6+9vZ/S8lnZ5mvKX7LtFXU7Qi7JdI2tXr8W6113zvLunPZrbFzOa0upk+jHT3vdn9fZJGtrKZPuRO491MZ0wz3jb7rp7pz4viBN03XevuP5A0TdL92cvVtuQ978Haaey0X9N4N0sf04x/rZX7rt7pz4tqRdj3SBrT6/HobFlbcPc92W2XpBfUflNR7z89g25229Xifr7WTtN49zXNuNpg37Vy+vNWhP0tSZeZ2Vgz+66kn0pa24I+vsHMBmcnTmRmgyX9WO03FfVaSbOy+7MkvdjCXv5Gu0zjXWuacbV437V8+nN3b/qPpJvUc0b+fyX9Uyt6qNHXpZL+J/t5t9W9SVqlnpd1x9VzbmO2pIslrZP0kaT/lnRRG/X2H5LekbRNPcEa1aLerlXPS/RtkrZmPze1et8l+mrKfuPjskAQnKADgiDsQBCEHQiCsANBEHYgCMIOBEHYgSD+DxCyZWhxyt+xAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAABG0AAADiCAYAAAD0xzrZAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAAaH0lEQVR4nO3deZRdVdkn4L0TyMxgIATwCxQCARTpgGFQQ3cEZFwqgwwqCIIocWCQIQqLQQICumgGgyIRiIJL0YjY2LSgiAyKiiykMYwKCSFZgaQh+iUShNTpP6rww3jfY3LrVt1dqedZqxbJ/tU59z0kO/fWW6fum6uqSgAAAACUZVC7CwAAAADgX2naAAAAABRI0wYAAACgQJo2AAAAAAXStAEAAAAokKYNAAAAQIE0bQqSc/5lzvnjfX0sUM/ehDLZm1AmexPKZG/2T5o2vSTnPCfnvFe760gppZzz5JxzZ8556Rs+jm53XdAOJe3NlFLKOX845zw357ws53xLznl0u2uCdihtb74u53xdzrnKOW/V7lqgHUramznnTXLO/yvnvKB7X3a0uyZol8L2Zs45n5Vzfjbn/Nec8/dyzuu2u641habNwLGgqqpRb/j4VrsLgoEu5/y2lNI3UkpHpZTGppT+llL6WluLAv4h5zwppbRlu+sA/qEzpfTTlNIh7S4E+CcfTV2vZ9+dUto0pTQ8pfTVtla0BtG06UM55zflnH+Sc16Uc36p+9f/sdKnbZlz/l13h/LHb/yue855t5zzr3POS3LOD+ecJ/fpBcAaqo178yMppVurqrqnqqqlKaWzU0oH55zXacmFQT/XzufNnPNaqesF52dbcjGwBmnX3qyq6vmqqr6WUnqgdVcDa442Pm++L6V0bVVV87pf016SUjo85zyiJRc2wGna9K1BKaXrU0qbp5Q2Sym9nFKavtLnfDSldGxKaZOU0msppStTSinn/OaU0v9OKV2QUhqdUjotpfTDnPOYlR8k57xZ90bb7A3LG+Wcn885P5NzviznPLK1lwb9Wrv25ttSSg+/nldV9eeU0t9TSuNbdmXQv7XzefOUlNI9VVX935ZeEawZ2rk3gVg792Ze6ddDU0pbt+CaBjxNmz5UVdX/q6rqh1VV/a2qqv9MKV2YUvofK33aDVVV/bGqqmWp67vuh+WcB6eUjkwp3VZV1W1VVXVWVfWzlNLvU0r7N3icZ6uqWr+qqme7lx5PKU1IXRtzj5TSO1JK/7M3rhH6ozbuzVEppb+s9Gl/SSm50wZS+/ZmznlcSumTKaVzevHyoN9q4/MmUKONe/OnKaWP55w7cs7rpZSmdq+706YFNG36UM55RM75G7nrTUf/mlK6J6W0fvcmed28N/x6bkpp7ZTShqmrW3pod0dzSc55SUppUupqxNSqqmphVVWPdm++Z1JKZyQ/Cwz/0K69mVJamlJa+U3a1k0p/WeTlwJrlDbuzctTSudXVbVyUxVIbd2bQI027s3rUkrfTSn9MqU0O6V0V/f6cz25Hrpo2vStU1NK26SUdq2qat2U0n/vXn/jrWTj3vDrzVJKr6aUFqeuzXVDd0fz9Y+RVVVd3EQdVfJnD2/Urr05O6X0317/Tc75LanrVtInm78UWKO0a2/umVL6Ss55Yc55Yffa/TnnD/foamDNUcprWuCftWVvdt8ccG5VVR1VVf1H6nqNO7/7gx7yhXvvWjvnPOz1j5TSm1LXzxUu6X7Dp3MbHHNkzvmt3W/adH5KaVZVVStSSjemlN6Xc94n5zy4+5yTG7yx1L/IOb8n57x57jIupXRxSunHLbtK6H+K2Jsppe90H7t79/tMnZ9Surn7dlYYiErZm+NTV0N1QvdHSl1vsvijHl0d9F+l7M3U/fhDu387tPv3MFAVsTdzzqNzzlt2f7351tT1VhznV1XV2bIrHcA0bXrXbalr07z+sX7qGn+2OKX0m9T1s38ruyGlNDOltDClNCyldGJKKVVVNS+l9IGU0pkppUWpqxN6emrwZ5i73hhqaf6vN4baMaX065TSsu7/PvL6eWGAKmJvVlU1O6V0Qupq3ryQut7L5lOtuUTol0rZmy90/2jxwqqqXr/TZnFVVS+35jKh3ylib3Z7OXX9eHFKXe/baF8ykJWyNzfsrmVZSun/pJSuq6rqmlZcICnlqqraXQMAAAAAK3GnDQAAAECBNG0AAAAACqRpAwAAAFAgTRsAAACAAmnaAAAAABRordX55A033LDq6OjopVKgbHPmzEmLFy/O7a6jEXuTgazUvWlfMtA9+OCDi6uqGtPuOlZmbzLQ2ZtQpmhvrlbTpqOjI/3+979vXVXQj0ycOLHdJYTsTQayUvemfclAl3Oe2+4aGrE3GejsTShTtDf9eBQAAABAgTRtAAAAAAqkaQMAAABQIE0bAAAAgAJp2gAAAAAUSNMGAAAAoECaNgAAAAAF0rQBAAAAKJCmDQAAAECBNG0AAAAACqRpAwAAAFAgTRsAAACAAmnaAAAAABRI0wYAAACgQJo2AAAAAAXStAEAAAAokKYNAAAAQIE0bQAAAAAKpGkDAAAAUCBNGwAAAIACadoAAAAAFEjTBgAAAKBAmjYAAAAABdK0AQAAACiQpg0AAABAgdZqdwEAa7J58+aF2RVXXBFml112WZidcsopYXbSSSeF2bhx48IMAAAojzttAAAAAAqkaQMAAABQIE0bAAAAgAJp2gAAAAAUSNMGAAAAoECmR7VJZ2dnmL3yyistfaxvfetbYbZs2bIwe/TRR8Ps8ssvD7MzzzwzzKZPnx5mw4cPD7NLL7204fqUKVPCY6CvzJ8/P8x23HHHMFuyZEmY5ZzDrG7/1e33RYsWhRnQHo899liY7bXXXmH2hz/8IczGjBnTk5JgjTJjxowwO+GEE8Ks7rX6E088EWbjx49ftcIAVpE7bQAAAAAKpGkDAAAAUCBNGwAAAIACadoAAAAAFEjTBgAAAKBAmjYAAAAABTLyu9tf/vKXMFuxYkWYPfzww2F2xx13hFndqN9rrrkmzPpSR0dHmJ166qlhdu2114bZeuutF2a77757mO2xxx5hBn1l7ty5DdcnT54cHvPSSy+FWd1Y77q9MnTo0DB74YUXwuzpp58Os8033zzMBg8eHGaU46mnnmq4Xvd3cJdddumtclhFv/3tb8Nszz337MNKoP+68847w+xzn/tcmA0a1Nz3r+uevwFazZ02AAAAAAXStAEAAAAokKYNAAAAQIE0bQAAAAAKpGkDAAAAUCBNGwAAAIACDaiR388991yYTZgwIczqxqX2d3WjDutGdw8fPjzMjjvuuDDbaKONwmzUqFFhNmbMmDCD1fXqq6+GWTTWO6WU9t1334br8+bN63FNK6v7N+nCCy8Ms0mTJoXZ1ltvHWbXXHNNmNXtacoRjbx9/PHHw2OM/O4bVVWFWTSqPaWUnnzyyd4oB9Y4dXtl+fLlfVgJlGHOnDlhNnPmzIbrP/3pT8NjHnjggabq+M53vhNm48aNC7Of/exnYXbMMceEWUdHx6qU1e+40wYAAACgQJo2AAAAAAXStAEAAAAokKYNAAAAQIE0bQAAAAAKpGkDAAAAUKABNfJ7gw02CLOxY8eGWSkjv/fee+8wq7u2m2++OcyGDh0aZpMnT16luqC/Of3008Ns+vTpfVhJ7O677w6zZcuWhdlBBx0UZnX/Fjz00EOrVhjFuvLKKxuu1z130DeWLl0aZhdddFGYnXTSSWE2ZsyYHtUE/c2jjz4aZuedd15T59xpp53C7I477gizkSNHNvV40Eq/+tWvwuywww4Ls+eff77helVV4TEHH3xwmM2bNy/MjjzyyDCrU1fLokWLwuyqq65q6vFK504bAAAAgAJp2gAAAAAUSNMGAAAAoECaNgAAAAAF0rQBAAAAKJCmDQAAAECBBtTI7+HDh4fZzJkzw2zWrFlh9s53vjPMDjnkkFWqa2WTJk1quP7jH/84PGbIkCFhtnDhwjC74oorVr0w6Efqxg/eeOONYVY3YjBSN2a77t+BujGI48aNC7PtttsuzKZOnRpmdf+WNXPdlGXFihXtLoHACSec0NRxdXsd1kR/+tOfwmz//fcPsxdffLGpx7v44ovDbL311mvqnLC6Ojs7w2zOnDlhdsABB4TZ0qVLw+zAAw9suH7BBReEx2y99dZhVvf649hjjw2z733ve2FW513veldTx/Vn7rQBAAAAKJCmDQAAAECBNG0AAAAACqRpAwAAAFAgTRsAAACAAmnaAAAAABRoQI38rrPzzjuH2Q477BBmdaO2zzjjjDD78pe/HGbTpk1b7ceqs/HGG4fZRRdd1NQ5oQTz588Psx133DHMlixZEmY55zD7yEc+0nB9xowZ4TGPPvpomNUdd8QRR4TZiBEjwmzTTTcNs0GD4j79DTfcEGaf//znw6xuNDmtt2DBgjCr2w+0V7PjiN/73ve2uBIo2ze/+c0wmzdvXlPnPPjgg8PsPe95T1PnhFa66667wmyfffZp6pyHH354mF133XUN14cOHdrUY913331h1uxY746OjjA76KCDmjpnf+ZOGwAAAIACadoAAAAAFEjTBgAAAKBAmjYAAAAABdK0AQAAACiQpg0AAABAgYz8XgXNjj9705ve1NRxV155ZcP13XffPTymbkwx9GeLFy8Os0suuSTMXnrppTAbO3ZsmG2xxRZhNmXKlIbrQ4YMCY+ZMGFCU1lf+9vf/hZmX/nKV8Is+veK3nHHHXeEWd2fIb1v2bJlYfbII480dc4NNtig2XKgWM0+3wwaFH+vuW6vTJs2bdUKg15U93rplFNOCbO6r/HOOeecMJs6dWqYNfu1beTkk09u6flSSummm24KsxEjRrT88UrnThsAAACAAmnaAAAAABRI0wYAAACgQJo2AAAAAAXStAEAAAAokKYNAAAAQIGM/O5FdePPfve734XZj370o4brs2fPDo/ZfvvtV7kuKM1rr70WZqeddlqY3XjjjWG23nrrhdntt98eZltttVWYvfrqq2G2JnvmmWfaXQLd/vjHP672MSWNll+TnXXWWWG2YMGCMNthhx3CbMiQIT2qCdppyZIlDdc/8IEPtPyxzjvvvDDbdtttW/540MjVV18dZnVjvetGcB9xxBFh9oUvfCHM1l577TCL1L0ef/jhh8PsqaeeCrOqqsKsbgz6xIkTw2wgcqcNAAAAQIE0bQAAAAAKpGkDAAAAUCBNGwAAAIACadoAAAAAFEjTBgAAAKBARn73orpRnddcc02Y3XnnnQ3X60YkHnjggWH27ne/O8wOOuigMMs5hxm00rPPPhtmdWO96/zmN78Js/Hjxzd1zuHDhzd1HLTTrrvu2u4SivPKK6+E2YMPPhhmdc/dN910U1O11I08HTZsWFPnhBLce++9Ddd//etfN3W+Qw89NMyOOeaYps4Jq2v58uVhNm3atDCr+7qqbqz3ddddt2qFrYYXX3yx4frhhx8eHnPXXXc19Vif/OQnw+z4449v6pwDkTttAAAAAAqkaQMAAABQIE0bAAAAgAJp2gAAAAAUSNMGAAAAoECmR7XJ6NGjw+z2229vuL7vvvuGx1x++eVNZXXvSH7IIYeE2ahRo8IMVtenP/3pMKuqKszqpp81OyFqTdbZ2RlmgwbFPfy6PwPKt2TJkj59vAULFoRZ3d/Bu+++O8yeeeaZMPv73//ecP2rX/1qeMyKFSvCbOTIkWG29957h1ndpKdXX301zLbbbrswg9I98MADYXb00Uev9vne9773hdmMGTPCzKQ1+krd88fzzz/f1Dkvu+yyMFu2bFmYzZo1K8zqJhref//9Ddf/+te/hsfUTb+qyz7+8Y+HWd2kZf6ZO20AAAAACqRpAwAAAFAgTRsAAACAAmnaAAAAABRI0wYAAACgQJo2AAAAAAUy8rtAu+yyS8P12bNnh8eccsopYfaDH/wgzI499tgw+/Of/xxmp59+epits846YcbA9dBDD4XZPffcE2Z1YwQPPfTQHtU00NSN9a77/zxx4sTeKIcmjBgxIsyiP8P3v//94THbbLNNj2taWTRKNKX68fFrrRW/JBk1alSY7brrrg3XTzvttPCY3XffPcwmTJgQZnXjwMeNGxdmdSNbx4wZE2ZQgiVLloTZbrvt1tLH2mqrrcKsbv9BXxk8eHCYbbzxxmG2cOHCMBs9enSY1b0+a9Zmm23WcH399dcPj5k3b16YjR07Nsx22mmnVa6LmDttAAAAAAqkaQMAAABQIE0bAAAAgAJp2gAAAAAUSNMGAAAAoECaNgAAAAAFMvK7H9lkk03CbObMmWF2wgknhNlee+0VZhdeeGGYPfHEE2F20003hRkD1/Lly8PslVdeCbNNN900zA444IAe1dRfvfbaa2F25ZVXNnXOD37wg2F25plnNnVOWu/8888Psy233LLh+i9/+cteqqaxrbfeOsw+/OEPh1ndqN8tttiiRzW1ym233RZmdeNct912294oB/rEpZdeGmaDBrX2+79Tp05t6fmg1YYNGxZm9913X5jttttuYbZo0aIwe+tb3xpmRx11VJh99KMfDbORI0eu9vnqRn5PmTIlzGgNd9oAAAAAFEjTBgAAAKBAmjYAAAAABdK0AQAAACiQpg0AAABAgTRtAAAAAApk5Pcaom783OTJk8Ns8ODBYVY3VviWW24Js7px4Ntss02YQSN1f7dHjRrVh5X0rbr99/Wvfz3MzjjjjDDr6OgIs7POOivMhgwZEmaU4+ijj16tdVbfT37yk6aOO/bYY1tcCbTW/Pnzw2zWrFktfayPfexjYTZmzJiWPhb0pbrXWQsXLuy7Qv6Np556quF63dd3gwbF93psu+22PS2Jf8OdNgAAAAAF0rQBAAAAKJCmDQAAAECBNG0AAAAACqRpAwAAAFAgTRsAAACAAhn53Y8sWLAgzG6++eYwu//++8OsbqxwnZ133jnMxo8f39Q5oZGjjjqq3SX0mroRq5dcckmYfe1rXwuzulGqM2bMWLXCgJY6+OCD210C1Jo4cWKYLV68uKlz7rPPPg3Xp0+f3tT5gNZYvnx5w/W6sd455zDbb7/9elwT9dxpAwAAAFAgTRsAAACAAmnaAAAAABRI0wYAAACgQJo2AAAAAAXStAEAAAAokJHfbbJo0aIwu+qqqxquX3/99eExzz33XI9rWtngwYPDrKOjI8zqRsIxcFVV1VQ2c+bMMDv77LN7UlKf+O53vxtmn/3sZ8PspZdeCrMTTzwxzC677LJVKwwAur3wwgthVjcGuM7UqVMbrg8ZMqSp8wGt8fa3v73dJbCa3GkDAAAAUCBNGwAAAIACadoAAAAAFEjTBgAAAKBAmjYAAAAABdK0AQAAACiQkd89tHTp0jC79dZbw+z8888PsyeffLJHNa2OPfbYI8wuvvjiMHvHO97RG+WwBqsbBV+X1Y2zr9tHxx13XJits846YTZ79uww+8Y3vtFw/d577w2PmTNnTphtueWWYXbEEUeEWd3Ib6A9qqoKs7lz54bZW97ylt4oB/7FaaedFmadnZ0tf7wddtih5ecEeu6RRx5pdwmsJnfaAAAAABRI0wYAAACgQJo2AAAAAAXStAEAAAAokKYNAAAAQIE0bQAAAAAKZOR3t2XLloXZvHnzwuzII48Ms4ceeqhHNa2OvffeO8y++MUvhtnOO+8cZnVjmKGvrFixIszqRn5fe+21YTZ69Ogwa/UYxP322y/M9t133zD7zGc+09I6gN5V95zZG+OUoZH58+eH2axZs8Js0KD4+7hDhw4Ns3PPPTfMRo4cGWZA+zz99NPtLoHV5E4bAAAAgAJp2gAAAAAUSNMGAAAAoECaNgAAAAAF0rQBAAAAKJCmDQAAAECB1riR3y+//HKYnXzyyWF23333hdnjjz/ek5JW2/77799w/ZxzzgmPmTBhQpitvfbaPS0Jeuxtb3tbmO21115h9vOf/7ypx3vuuefCrG4kap2NNtqo4fqUKVPCY84+++ymHgtYc/ziF78Isz333LMPK2FNt3Tp0jBr9rmvo6MjzKZOndrUOYH22WWXXRqud3Z2hscMGuRej3byfx8AAACgQJo2AAAAAAXStAEAAAAokKYNAAAAQIE0bQAAAAAKpGkDAAAAUKCiR37PmTOn4fqXvvSl8Ji68cBz587taUmrZcSIEWE2bdq0MPvUpz7VcH3IkCE9rgnaZd111w2zWbNmhdm3v/3tMDvxxBN7VFMjF1xwQZgdf/zxDdc32GCDltcB9C9VVbW7BAD4tzbZZJOG69tvv314zGOPPRZmzz//fJhtscUWq14YIXfaAAAAABRI0wYAAACgQJo2AAAAAAXStAEAAAAokKYNAAAAQIGKnh71wx/+sOH6tdde2/LH2mmnncLsQx/6UJittVb8v/ATn/hEmA0bNmzVCoMBYNSoUWEWTVP7dxlAqx1yyCFhdvXVV/dhJdDYm9/85jA74IADwuzWW2/tjXKAfuTyyy8Ps3322SfMzjjjjDCbPn16mI0dO3aV6sKdNgAAAABF0rQBAAAAKJCmDQAAAECBNG0AAAAACqRpAwAAAFAgTRsAAACAAhU98vvUU09drXUAgN6y5557hllnZ2cfVgKNjRo1KsxuueWWvisE6HcmTZoUZocddliYff/73w+zDTfcMMyuuOKKMBsyZEiYDUTutAEAAAAokKYNAAAAQIE0bQAAAAAKpGkDAAAAUCBNGwAAAIACadoAAAAAFKjokd8AAABA7xo6dGiYXX/99WG2zTbbhNm0adPC7LzzzguzsWPHhtlA5E4bAAAAgAJp2gAAAAAUSNMGAAAAoECaNgAAAAAF0rQBAAAAKJCmDQAAAECBjPwGAAAAGqobB37uuec2lbHq3GkDAAAAUCBNGwAAAIACadoAAAAAFEjTBgAAAKBAmjYAAAAABdK0AQAAAChQrqpq1T8550Uppbm9Vw4UbfOqqsa0u4hG7E0GuCL3pn0J9iYUyt6EMjXcm6vVtAEAAACgb/jxKAAAAIACadoAAAAAFEjTBgAAAKBAmjYAAAAABdK0AQAAACiQpg0AAABAgTRtAAAAAAqkaQMAAABQIE0bAAAAgAL9f6FBXveL/EtXAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 1440x1440 with 5 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "(x_train, y_train), (x_test, y_test) = tensorflow.keras.datasets.mnist.load_data()\n",
    "print(\"Training set size:\", len(x_train))\n",
    "print(\"Testing set size:\", len(x_test))\n",
    "\n",
    "plt.imshow(x_train[5], cmap='Greys')\n",
    "print(\"The label is\", y_train[5])\n",
    "\n",
    "# Plotting some examples of numbers\n",
    "fig = plt.figure(figsize=(20,20))\n",
    "for i in range(5):\n",
    "    ax = fig.add_subplot(1, 5, i+1, xticks=[], yticks=[])\n",
    "    ax.imshow(x_train[i], cmap='Greys')\n",
    "    ax.set_title('Label:' + str(y_train[i]))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "79daaaf9-12a2-4cf1-aa12-6440b4c95413",
   "metadata": {},
   "source": [
    "## Pre-processing the data\n",
    "\n",
    "We'll have to reshape the data to make it train faster, and we'll also use the pandas.factorize to ensure we get the classes despite them being synonymous with numbers."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "c5ae6110-15c3-4d76-92b9-349aad4ef8ff",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(60000, 784)\n",
      "(60000,)\n",
      "10\n"
     ]
    }
   ],
   "source": [
    "# Reshaping the features\n",
    "num_columns = 28*28\n",
    "x_train_reshaped = x_train.reshape(-1, num_columns) # -1 is a placeholder for the dataset's size\n",
    "x_test_reshaped = x_test.reshape(-1, num_columns)\n",
    "\n",
    "# This isn't necessary since numbers are coincidentally the same as the class labels, but it's a good practice.\n",
    "y_train_cat, y_train_labels = pandas.factorize(y_train, sort=True) # returns 1D array\n",
    "y_test_cat, y_test_labels = pandas.factorize(y_test, sort=True) # sort to ensure class labels and numbers align\n",
    "\n",
    "# The test and train are balanced so they should have the same number of classes\n",
    "num_labels = len(set( list(y_test_labels) + list(y_test_labels))) # Add lists then use set to get unique values\n",
    "assert num_labels == 10\n",
    "\n",
    "print(x_train_reshaped.shape)\n",
    "print(y_train_cat.shape) \n",
    "print(num_labels)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ba55e35f-5126-493e-b2e4-457f18010059",
   "metadata": {},
   "source": [
    "## Model Structure\n",
    "\n",
    "We'll use a sequential model to add a few dense layers with the LeakyReLu activation function. ReLu is good, but it's prone to zeroing out. We'll once againt add dropout and set our unit sizes as powers of 2. We also use sparse_categorical_crossentropy since our labels are 1D."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "4be13fb7-3cc2-40cb-86f7-f96585db7703",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"sequential\"\n",
      "_________________________________________________________________\n",
      " Layer (type)                Output Shape              Param #   \n",
      "=================================================================\n",
      " dense (Dense)               (None, 128)               100480    \n",
      "                                                                 \n",
      " leaky_re_lu (LeakyReLU)     (None, 128)               0         \n",
      "                                                                 \n",
      " dropout (Dropout)           (None, 128)               0         \n",
      "                                                                 \n",
      " dense_1 (Dense)             (None, 64)                8256      \n",
      "                                                                 \n",
      " leaky_re_lu_1 (LeakyReLU)   (None, 64)                0         \n",
      "                                                                 \n",
      " dropout_1 (Dropout)         (None, 64)                0         \n",
      "                                                                 \n",
      " dense_2 (Dense)             (None, 10)                650       \n",
      "                                                                 \n",
      "=================================================================\n",
      "Total params: 109,386\n",
      "Trainable params: 109,386\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "# Building the model\n",
    "num_units_penultimate = 2**6\n",
    "model = Sequential()\n",
    "model.add(Dense(2**7, input_shape=(num_columns,)))\n",
    "model.add(LeakyReLU(alpha = 0.01))\n",
    "model.add(Dropout(.2))\n",
    "model.add(Dense(num_units_penultimate))\n",
    "model.add(LeakyReLU(alpha = 0.01))\n",
    "model.add(Dropout(.2))\n",
    "model.add(Dense(num_labels, activation='softmax')) # two or more classes\n",
    "\n",
    "# Compiling the model\n",
    "model.compile(\n",
    "    loss = 'sparse_categorical_crossentropy',\n",
    "    optimizer='adam',\n",
    "    metrics=['accuracy']\n",
    ")\n",
    "model.summary()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c1926762-899b-4c29-b342-21fa46c941d0",
   "metadata": {},
   "source": [
    "## Training the Model\n",
    "It's a good practice to set your batch size equal to the number of penultimate neurons in your layer. Training accuracy will vary due to the stochastic nature of the model, but you can clearly see the loss going down and the accuracy going up. We'll save the model to the history instance so we can get the metadata."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "551f675e-2a46-466d-9442-31f21edb1205",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/10\n",
      "938/938 [==============================] - 3s 3ms/step - loss: 3.1994 - accuracy: 0.6949\n",
      "Epoch 2/10\n",
      "938/938 [==============================] - 3s 3ms/step - loss: 0.6405 - accuracy: 0.8247\n",
      "Epoch 3/10\n",
      "938/938 [==============================] - 3s 3ms/step - loss: 0.4863 - accuracy: 0.8638\n",
      "Epoch 4/10\n",
      "938/938 [==============================] - 3s 3ms/step - loss: 0.4073 - accuracy: 0.8860\n",
      "Epoch 5/10\n",
      "938/938 [==============================] - 3s 3ms/step - loss: 0.3610 - accuracy: 0.9014\n",
      "Epoch 6/10\n",
      "938/938 [==============================] - 3s 4ms/step - loss: 0.3193 - accuracy: 0.9100\n",
      "Epoch 7/10\n",
      "938/938 [==============================] - 3s 3ms/step - loss: 0.2835 - accuracy: 0.9211\n",
      "Epoch 8/10\n",
      "938/938 [==============================] - 3s 3ms/step - loss: 0.2641 - accuracy: 0.9273\n",
      "Epoch 9/10\n",
      "938/938 [==============================] - 3s 3ms/step - loss: 0.2521 - accuracy: 0.9296\n",
      "Epoch 10/10\n",
      "938/938 [==============================] - 4s 4ms/step - loss: 0.2288 - accuracy: 0.9354\n"
     ]
    }
   ],
   "source": [
    "history = model.fit(x_train_reshaped, y_train_cat, epochs=10, batch_size=num_units_penultimate, verbose = 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "ece2e5d9-f714-458a-b17a-398ecf37ade3",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'validation_data': None,\n",
       " 'model': <keras.engine.sequential.Sequential at 0x1da4b072c10>,\n",
       " '_chief_worker_only': None,\n",
       " '_supports_tf_logs': False,\n",
       " 'history': {'loss': [3.199434995651245,\n",
       "   0.6405399441719055,\n",
       "   0.486316978931427,\n",
       "   0.4073334038257599,\n",
       "   0.3610023856163025,\n",
       "   0.31927186250686646,\n",
       "   0.2834741473197937,\n",
       "   0.2640761733055115,\n",
       "   0.2520712912082672,\n",
       "   0.22878681123256683],\n",
       "  'accuracy': [0.6948999762535095,\n",
       "   0.8246999979019165,\n",
       "   0.8637833595275879,\n",
       "   0.8860499858856201,\n",
       "   0.9014333486557007,\n",
       "   0.9099500179290771,\n",
       "   0.9210666418075562,\n",
       "   0.927299976348877,\n",
       "   0.9296333193778992,\n",
       "   0.9353833198547363]},\n",
       " 'params': {'verbose': 1, 'epochs': 10, 'steps': 938},\n",
       " 'epoch': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]}"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "history.__dict__ # Neat trick to get class attributes"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a0e6cdc3-3d6a-49e4-bcff-f11b5b8a8f2e",
   "metadata": {},
   "source": [
    "## Making predictions\n",
    "\n",
    "We'll have the trained model return a matrix of predictions and use argmax to get the index with the highest probability to get the prediction in each row. For example, if there are 3 classes and the matrix has a row like [0.4, 0.1, 0.5], then the third class (index 2) will be the class our softmax produced."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "af4b399f-809b-4470-a2de-ef3ea9f7fd07",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "313/313 [==============================] - 1s 2ms/step\n",
      "(10000, 10)\n",
      "Matrix row of predictions: [1.58623618e-24 2.82618416e-18 1.09072137e-12 2.88809243e-08\n",
      " 2.28177921e-29 4.77224047e-24 1.09657014e-35 1.00000000e+00\n",
      " 4.32105779e-24 3.43912225e-11]\n"
     ]
    }
   ],
   "source": [
    "matrix_predictions = model.predict(x_test_reshaped)\n",
    "print(matrix_predictions.shape) # the columns equals the number of classes\n",
    "print(\"Matrix row of predictions:\", matrix_predictions[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "2d8728d5-e9f8-4414-957f-cbb301315a77",
   "metadata": {},
   "outputs": [],
   "source": [
    "predictions = [numpy.argmax(pred) for pred in matrix_predictions]\n",
    "array_correct = predictions == y_test_cat # numpy arrays allow point-wise comparisons, and this is a boolean 1D vector"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "89f44fbf-1291-4585-afa5-1a83ca5ac761",
   "metadata": {},
   "source": [
    "Sometimes we'll get it right"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "63769012-7a9a-4d1a-9932-038b8a8a2fa2",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAOsAAADrCAYAAACICmHVAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAAFxElEQVR4nO3bwYuNexzH8TPXhLOi4RYpDisnpZQyydrGysJiWCh7GxvZocSSP0AUe2QlGzULlMVsLDQZc7NRJkXJQubc9e16vmPOGcf5zHm9th/PM8/Cu9/Ur5no9XotYPT99ac/APg1YoUQYoUQYoUQYoUQYoUQk6v5x9u3b+91Op3f9CnA4uJia2lpaeJn26pi7XQ6rVevXq3NVwH/c/jw4cbNr8EQQqwQQqwQQqwQQqwQQqwQQqwQQqwQQqwQQqwQQqwQQqwQQqwQQqwQQqwQQqwQQqwQQqwQQqwQQqwQQqwQQqwQQqwQQqwQQqwQQqwQQqwQQqwQQqwQQqwQQqwQQqwQQqwQQqwQQqwQQqwQQqwQQqwQQqwQQqwQQqwQQqwQQqwQQqwQQqwQQqwQQqwQQqwQQqwQQqwQQqwQYvJPf8CwvHjxonG7detW+eyuXbvKvd1ul/vZs2fLfWpqqq+N8eJkhRBihRBihRBihRBihRBihRBihRBjc89a3XXOz8//1p997dq1ct+yZUvjNj09vdafE6PT6TRuly5dKp/dvXv3Gn/Nn+dkhRBihRBihRBihRBihRBihRBihRBjc8/68OHDxm1ubq589sCBA+X++vXrcn/58mW5P3r0qHF78uRJ+ezevXvL/d27d+U+iMnJ+r/Pzp07y/39+/d9/+zqDrbVarUuXrzY97tHlZMVQogVQogVQogVQogVQogVQogVQozNPWu32+1r+xUHDx4s95mZmXK/ceNG47a4uFg+u9I968LCQrkPYuPGjeW+0j3rSt/+8ePHxm3//v3ls+uRkxVCiBVCiBVCiBVCiBVCiBVCiBVCjM096yjbvHlz4zbofeKgd8iDWOnveJeWlsr9yJEjjdvx48f7+qZkTlYIIVYIIVYIIVYIIVYIIVYI4eqGvn39+rXcT548We7Ly8vlfvPmzcat3W6Xz65HTlYIIVYIIVYIIVYIIVYIIVYIIVYI4Z6Vvt29e7fcP3z4UO7btm0r9z179qz2k9Y1JyuEECuEECuEECuEECuEECuEECuEcM9K6e3bt43bhQsXBnr38+fPy33Hjh0DvX+9cbJCCLFCCLFCCLFCCLFCCLFCCLFCCPeslB4/fty4ff/+vXz21KlT5b5v376+vmlcOVkhhFghhFghhFghhFghhFghhFghhHvWMbfSXemDBw8at02bNpXPXr9+vdw3bNhQ7vyXkxVCiBVCiBVCiBVCiBVCiBVCuLoZc7dv3y732dnZxu306dPls/4Ebm05WSGEWCGEWCGEWCGEWCGEWCGEWCGEe9Z1bm5urtzPnz9f7lu3bm3crl692scX0S8nK4QQK4QQK4QQK4QQK4QQK4QQK4Rwzxru27dv5T4zM1PuP378KPczZ840bv5edbicrBBCrBBCrBBCrBBCrBBCrBBCrBDCPeuIW15eLvcTJ06U+5s3b8q92+2W+5UrV8qd4XGyQgixQgixQgixQgixQgixQghXNyPu06dP5f7s2bOB3n/v3r1yn5qaGuj9rB0nK4QQK4QQK4QQK4QQK4QQK4QQK4RwzzoCPn/+3LhNT08P9O779++X+6FDhwZ6P8PjZIUQYoUQYoUQYoUQYoUQYoUQYoUQ7llHwJ07dxq3hYWFgd597Nixcp+YmBjo/QyPkxVCiBVCiBVCiBVCiBVCiBVCiBVCuGcdgvn5+XK/fPnycD6EaE5WCCFWCCFWCCFWCCFWCCFWCCFWCOGedQhmZ2fL/cuXL32/u9vtlnu73e773YwWJyuEECuEECuEECuEECuEECuEcHUz4o4ePVruT58+LXdXN+uHkxVCiBVCiBVCiBVCiBVCiBVCiBVCuGcdgnPnzg20Q6vlZIUYYoUQYoUQYoUQYoUQYoUQYoUQE71e79f/8cTEx1ar9c/v+xwYe3t6vd7fPxtWFSvw5/g1GEKIFUKIFUKIFUKIFUKIFUKIFUKIFUKIFUL8Cw+9sldSgzxgAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The label is 7\n",
      "The prediction is 7\n"
     ]
    }
   ],
   "source": [
    "index_correct = numpy.where(array_correct == True)[0][0] # Where returns tuple\n",
    "plt.imshow(x_test[index_correct], cmap='Greys')\n",
    "plt.xticks([])\n",
    "plt.yticks([])\n",
    "plt.show()\n",
    "print(\"The label is\", y_test_cat[index_correct])\n",
    "print(\"The prediction is\", predictions[index_correct])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b8b6ed48-b890-4f9c-bd4d-9daca85c9574",
   "metadata": {},
   "source": [
    "Sometimes we'll get it wrong"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "67e9a70b-c2b5-496e-a61f-0ca79d9899cd",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAOsAAADrCAYAAACICmHVAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAAGr0lEQVR4nO3dTYjN/wLH8XOYkZLH3EyiGaOUlDyMjYVGNhaUnVKWTJSVh6JZEykbDwmb2ViIjfVgI9IsRFLTJDcLizspFBrGuatbt3ud75kZ8/Q583ptP37mJ//3/6u+nZlqrVarALPfvJl+AWBsxAohxAohxAohxAohxAohWsbzi1euXFnr6OiYolcB3r9/XxkeHq7+aRtXrB0dHZWBgYHJeSvg/3R1ddXd/DMYQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQozrB1PBf/vx40dx//Tp05R97RUrVhT3O3fuFPdt27YV9/b29uK+evXq4j4VnKwQQqwQQqwQQqwQQqwQQqwQQqwQwj3rHPfy5cvifu/evbrbw4cPi8++efNmIq80Jps3by7ug4ODxb3RHXEjo6Ojf/X8RDhZIYRYIYRYIYRYIYRYIYRYIYRYIYR71lmu0WdCb968WdzPnz9f3L9//17ca7VacZ8pr169mulXmHZOVgghVgghVgghVgghVgghVgjh6maWGx4eLu69vb3T9CbTb+vWrXW3HTt2TOObzA5OVgghVgghVgghVgghVgghVgghVgjhnnUMvn37Vtxv375d3Lu7u4t76dtqtrSU/4qWL19e3BcvXlzcv379WtwPHjxYd9uyZUvx2Z07dxb3devWFffSn33BggXFZ5uRkxVCiBVCiBVCiBVCiBVCiBVCiBVCuGetVCojIyPFfe/evcX96dOnxf3Fixfjfqf/6OzsLO5DQ0PFfdmyZcX98+fPxX3JkiV1t2q1WnyWyeVkhRBihRBihRBihRBihRBihRBihRBz5p51dHS07tbT01N8ttE96pUrV4p76fOqf6vRPWojS5cunZwXYco5WSGEWCGEWCGEWCGEWCGEWCGEWCFE09yzNvpM6vXr1+tufX19xWdXrVpV3I8ePVrcW1tbizuMhZMVQogVQogVQogVQogVQogVQjTN1c2zZ8+K+8mTJ+tu69evLz47MDBQ3BcuXFjcYTI4WSGEWCGEWCGEWCGEWCGEWCGEWCFE09yz9vf3T/jZXbt2FffSjz2E6eJkhRBihRBihRBihRBihRBihRBihRBNc89669atCT979+7d4r59+/bivn///uK+Zs2acb8T/C8nK4QQK4QQK4QQK4QQK4QQK4QQK4So1mq1Mf/irq6uWqPvoTtTqtVqcZ83b+r+v9To9+7t7S3uu3fvrrsNDQ0Vn924cWNx7+zsLO6NvHv3ru62adOm4rM+Bzx+XV1dlYGBgT/+x+xkhRBihRBihRBihRBihRBihRBihRBNc8966dKl4n727NlpepO5o62trbgfOHCguF+7dm0S36Y5uGeFJiBWCCFWCCFWCCFWCCFWCNE0Vze/f/8u7h8+fKi77du3r/jsyMhIcS99jKxSafxuzarRxxZv3LhR3I8cOTKZrxPB1Q00AbFCCLFCCLFCCLFCCLFCCLFCiKb5kY+Nvh1oe3t73e3169d/9bXfvn1b3H/+/FncT506VXfr7++f0DvNBo3u8J8/f17c5+I9a4mTFUKIFUKIFUKIFUKIFUKIFUKIFUI0zT3rTGr0YxcbOXToUN2t0T1rS0v5r/D06dPFvaenp7hfvny57nb16tXis0wuJyuEECuEECuEECuEECuEECuEECuEcM86C+zZs2fCz/769au4X7hwobgPDg4W9wcPHoz7ncZq7dq1U/Z7NyMnK4QQK4QQK4QQK4QQK4QQK4RwdTMLtLW11d2OHTtWfLbRj01s5P79+xN+dv78+cX98OHDxf3cuXMT/tpzkZMVQogVQogVQogVQogVQogVQogVQrhnnQVaW1vrbhcvXiw+++XLl+L+6NGj4v7x48fivmHDhrrbiRMnis8eP368uDM+TlYIIVYIIVYIIVYIIVYIIVYIIVYI4Z51llu0aFFx7+vrK+5Pnjwp7o8fPy7uZ86cqbs1ejcml5MVQogVQogVQogVQogVQogVQogVQrhnbXLd3d1/tTN7OFkhhFghhFghhFghhFghhFghhFghhFghhFghhFghhFghhFghhFghhFghhFghhFghhFghhFghhFghhFghhFghhFghhFghhFghhFghhFghhFghhFghhFghhFghRLVWq439F1er/6pUKv+cuteBOa+9Vqv940/DuGIFZo5/BkMIsUIIsUIIsUIIsUIIsUIIsUIIsUIIsUKIfwNCrweG+y92hgAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The label is 5\n",
      "The prediction is 6\n"
     ]
    }
   ],
   "source": [
    "index_incorrect = numpy.where(array_correct == False)[0][0] # Where returns tuple\n",
    "plt.imshow(x_test[index_incorrect], cmap='Greys')\n",
    "plt.xticks([])\n",
    "plt.yticks([])\n",
    "plt.show()\n",
    "print(\"The label is\", y_test_cat[index_incorrect])\n",
    "print(\"The prediction is\", predictions[index_incorrect])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e5b45140-b676-4042-8bf6-52d9d800de48",
   "metadata": {},
   "source": [
    "## Finding the accuracy of the model on the test set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "f141d692-f4b2-4c39-a643-121050b9258b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The model is correct 9538 times out of 10000\n",
      "The accuracy is 0.9538\n"
     ]
    }
   ],
   "source": [
    "num_correct = array_correct.sum() # True is 1 and False is 0.\n",
    "print(\"The model is correct\", num_correct, \"times out of\", len(y_test_cat))\n",
    "print(\"The accuracy is\", num_correct/len(y_test_cat))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8abcbf56-9b76-4799-977c-c94fe27a3505",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
